import os
import openai
import json
from typing import Any, List, Dict, Tuple
from pipelines.prompta.utils import query2str, get_value_by_key, tuple2word, word2tuple
from prompta.core.language import BaseLanguage
from .base_oracle import BaseOracle
from .base_llm_oracle import BaseLLMOracle
from .oracle_wrapper import TemporaryMembershipOracleWrapper
from pipelines.prompta.learner.java_utils.dfa import TTTLearnerDFA
from prompta.utils.java_libs import AcexAnalyzers


def levenshtein_distance_tuple(tuple1, tuple2):
    """Return the Levenshtein (edit) distance between two sequences.

    Elements are compared with ``==``, so each element of a tuple counts as
    one symbol, however long it is. Insertions, deletions and substitutions
    each cost 1.

    Example usage:
        tuple1 = ('aaa', 'a', 'bbb')
        tuple2 = ('a', 'aaa', 'bbb')
        distance = levenshtein_distance_tuple(tuple1, tuple2)
        print("Levenshtein distance:", distance)   # -> 2
    """
    # Row for the empty prefix of tuple1: reaching a j-long prefix of tuple2
    # from nothing takes j insertions.
    prev_row = list(range(len(tuple2) + 1))

    for row_idx, left in enumerate(tuple1, start=1):
        # Turning a row_idx-long prefix of tuple1 into nothing takes row_idx deletions.
        cur_row = [row_idx]
        for col_idx, right in enumerate(tuple2, start=1):
            if left == right:
                # Matching symbols: carry the diagonal cost over unchanged.
                cur_row.append(prev_row[col_idx - 1])
            else:
                cur_row.append(1 + min(prev_row[col_idx],        # deletion
                                       cur_row[col_idx - 1],     # insertion
                                       prev_row[col_idx - 1]))   # substitution
        prev_row = cur_row

    return prev_row[-1]


class DictBasedOracle(BaseOracle):
    """Membership oracle that answers only from recorded counterexamples.

    The ``ce_cache`` keyword argument must be a list of counterexample objects
    exposing ``getInput()`` / ``getOutput()``. The list is kept by reference,
    so counterexamples the caller appends after construction are also used.
    Strings with no recorded counterexample are answered ``False``.
    """

    def __init__(
            self,
            language: BaseLanguage,
            *args: Any,
            **kwargs: Any
    ) -> None:
        super().__init__(language, *args, **kwargs)
        # Shared by reference with the caller, which may keep appending to it.
        self.ce_history = kwargs['ce_cache']
        # How many entries of ce_history have already been copied into query_history.
        self._num_indexed = 0
        self._index_new_counterexamples()

    def _index_new_counterexamples(self) -> None:
        """Copy counterexamples appended to ``ce_history`` since the last call into ``query_history``."""
        if len(self.ce_history) < self._num_indexed:
            # The shared list shrank (e.g. was cleared); rescan it from the start.
            self._num_indexed = 0
        for ce in self.ce_history[self._num_indexed:]:
            self.query_history[query2str(ce.getInput())] = ce
        self._num_indexed = len(self.ce_history)

    def __call__(self, input_str: str, *args: Any, **kwds: Any):
        """Return the recorded output for ``input_str``, or ``False`` if it is unknown.

        Args:
            input_str: query in the same string form produced by ``query2str``.
        """
        # Pick up counterexamples added to the shared list after construction.
        self._index_new_counterexamples()
        if input_str not in self.query_history:
            return False
        entry = self.query_history[input_str]
        # Entries indexed here are counterexample objects; answer with their
        # verified output rather than the (always truthy) object itself.
        get_output = getattr(entry, 'getOutput', None)
        return get_output() if callable(get_output) else entry

    def check_conjecture(self, aut, _type=...) -> Any:
        """Return the first recorded counterexample that ``aut`` still misclassifies.

        A counterexample is still a counterexample when the automaton's
        acceptance of its input differs from its recorded output. Recorded
        negative examples that ``aut`` correctly rejects therefore do not count.

        Args:
            aut: hypothesis automaton exposing ``accepts(word)``.
            _type: unused; kept for interface compatibility.

        Returns:
            The offending counterexample object, or ``None`` if ``aut`` agrees
            with every recorded counterexample.
        """
        for ce in self.ce_history:
            if bool(aut.accepts(ce.getInput())) != bool(ce.getOutput()):
                return ce
        return None


class DiscriminatorLLMOracle(BaseLLMOracle):
    """LLM membership oracle whose prompts are grounded in verified counterexamples.

    Counterexamples returned by ``check_conjecture`` are appended to
    ``ce_cache``. Once at least two are cached, each membership prompt shows
    two of them as worked examples, answered with their verified outputs.
    With more than two cached, the pair is chosen using the discrimination
    tree of a TTT learner (``core_learner``, created in ``reset``).
    """

    def __init__(
            self,
            language: BaseLanguage,
            model_name: str,
            *args: Any,
            **kwargs: Any
    ) -> None:
        super().__init__(language, model_name, *args, **kwargs)
        self.ce_cache = []
        self.tmp_oracle = TemporaryMembershipOracleWrapper(DictBasedOracle(language, ce_cache=self.ce_cache))

    def reset(self, language: BaseLanguage, exp_dir, alphabet=None, load_history=False):
        """Clear the counterexample cache, rebuild the TTT learner, then reset the base oracle."""
        self.ce_cache = []
        # The dict-based oracle shares the new ce_cache list by reference.
        self.tmp_oracle = TemporaryMembershipOracleWrapper(DictBasedOracle(language, ce_cache=self.ce_cache))
        self.core_learner = TTTLearnerDFA(self.tmp_oracle.jalphabet, self.tmp_oracle, AcexAnalyzers.BINARY_SEARCH_BWD)
        self.core_learner.startLearning()

        return super().reset(language, exp_dir, alphabet, load_history)

    def _get_membership_query_result(self, input_str: str, use_cache: bool=True, seed: int=0, *args: Any, **kwargs: Any) -> Any:
        """Ask the LLM whether ``input_str`` is in the language, caching answer and reason.

        Returns the cached answer when ``use_cache`` is true and ``input_str`` was asked before.
        """
        if use_cache and input_str in self.llm_resp_cache:
            return self.llm_resp_cache[input_str]['answer']

        queries = self._construct_existence_message(input_str)
        result = self._get_json_resp(queries, seed=seed)
        ans = get_value_by_key(result, "answer", is_boolean=True)
        rsn = get_value_by_key(result, "reason")
        self.llm_resp_cache[input_str] = {'answer': ans, 'reason': rsn}
        return ans

    def _construct_existence_message(self, query: str):
        """Build the chat messages asking whether ``query`` belongs to the language.

        Without two related counterexamples, the language's canned positive and
        negative examples serve as few-shot turns. Otherwise each of the two
        related counterexamples becomes a few-shot turn answered with its
        verified output, and both are repeated in the final question.
        """
        system_msg = {"role": "system", "content": "You are a helpful assistant designed to output JSON. Answer in a consistent style and output the reason first."}
        examples = self._get_related_query_by_discrimination_tree(query)
        if examples is None:
            return [
                system_msg,
                {"role": "user", "content": f"{self.language.definition}. {self.language.examples['pos']['query']}"},
                {"role": "assistant", "content": self.language.examples['pos']['answer']},
                {"role": "user", "content": f"{self.language.examples['neg']['query']}"},
                {"role": "assistant", "content": self.language.examples['neg']['answer']},
                {"role": "user", "content": f"Given a string \"{query}\", does this string belongs to the language?"}
            ]

        # Few-shot turns come from the counterexamples, not from the query being
        # asked. Repeating the query with contradictory "Yes"/"No" answers would
        # mislead the model.
        messages = [system_msg]
        for idx, ce in enumerate(examples):
            # Only the first turn carries the language definition.
            prefix = f"{self.language.definition}. " if idx == 0 else ""
            messages.append({"role": "user", "content": f"{prefix}Given a string \"{query2str(ce.getInput())}\", does this string belongs to the language?"})
            messages.append({"role": "assistant", "content": "Yes" if bool(ce.getOutput()) else "No"})

        verified = [(word2tuple(ce.getInput()), ce.getOutput()) for ce in examples]
        messages.append({"role": "user", "content": f"There are two verified examples: {verified[0]}, {verified[1]}. Given a string \"{query}\", does this string belongs to the language?"})
        return messages

    def _get_related_query_by_discrimination_tree(self, query):
        """Return two cached counterexamples related to ``query``, or ``None``.

        Returns ``None`` when fewer than two counterexamples are cached, and a
        copy of the cache when exactly two are cached. Otherwise finds the
        hypothesis state reached by ``query`` and a state from the subtree of its
        parent in the discrimination tree. For each of the two states it takes
        the cached counterexample reaching that state that is closest to
        ``query`` by edit distance. Returns ``None`` if any of these lookups
        comes up empty.
        """
        if len(self.ce_cache) < 2:
            return None
        if len(self.ce_cache) == 2:
            return list(self.ce_cache)

        # core_learner only exists after reset(); without it there is no tree to consult.
        core_learner = getattr(self, 'core_learner', None)
        if core_learner is None:
            return None

        ttt_hypothesis = core_learner.getHypothesisDS()
        tttstate0 = ttt_hypothesis.getState(tuple2word(query))
        if tttstate0 is None:
            return None
        parent = tttstate0.getDTLeaf().getParent()
        if parent is None:
            return None
        # NOTE(review): anySubtreeState() may return tttstate0 itself rather than
        # a state on the sibling branch; confirm whether a sibling was intended.
        tttstate1 = parent.anySubtreeState()

        closest = [self.get_closest_ce(query, ttt_hypothesis, state) for state in (tttstate0, tttstate1)]
        # get_closest_ce returns None when no cached counterexample reaches a state.
        if any(ce is None for ce in closest):
            return None
        return closest

    def get_closest_ce(self, query, ttt_hypothesis, tttstate):
        """Return the cached counterexample reaching ``tttstate`` closest to ``query``.

        Closeness is the Levenshtein distance over symbol tuples. On a tie the
        earliest cached counterexample wins. Returns ``None`` when no cached
        counterexample reaches ``tttstate``.
        """
        # Put the query in the same tuple-of-symbols form that word2tuple gives
        # for counterexample inputs, so the distance compares like with like.
        query_tuple = word2tuple(tuple2word(query))
        target = str(tttstate)

        best_ce, best_dist = None, None
        for ce in self.ce_cache:
            if str(ttt_hypothesis.getState(ce.getInput())) != target:
                continue
            dist = levenshtein_distance_tuple(word2tuple(ce.getInput()), query_tuple)
            if best_dist is None or dist < best_dist:
                best_ce, best_dist = ce, dist
        return best_ce

    def check_conjecture(self, aut, _type=...):
        """Delegate to the base oracle and cache any counterexample it returns."""
        ce = super().check_conjecture(aut, _type)
        if ce is not None:
            # NOTE(review): core_learner is not refined with this counterexample,
            # so its hypothesis stays as built in reset(); confirm whether
            # core_learner.refineHypothesis(ce) was intended here.
            self.ce_cache.append(ce)
        return ce
